{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyM3TaApvacHQONh9AbS8ZKW",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sugarforever/LangChain-Tutorials/blob/main/LangChain_LLM_Math.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Pin the versions this notebook was last run with so the tutorial stays reproducible;\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install langchain==0.0.151 openai==0.27.5"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "f9iqJ8-RM_2Q",
        "outputId": "f4141238-2a64-4c24-a28d-bb6c8b17db19"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Collecting langchain\n",
            "  Downloading langchain-0.0.151-py3-none-any.whl (665 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m666.0/666.0 kB\u001b[0m \u001b[31m10.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting openai\n",
            "  Downloading openai-0.27.5-py3-none-any.whl (71 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.6/71.6 kB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: pydantic<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (1.10.7)\n",
            "Requirement already satisfied: PyYAML>=5.4.1 in /usr/local/lib/python3.10/dist-packages (from langchain) (6.0)\n",
            "Collecting dataclasses-json<0.6.0,>=0.5.7\n",
            "  Downloading dataclasses_json-0.5.7-py3-none-any.whl (25 kB)\n",
            "Requirement already satisfied: SQLAlchemy<3,>1.3 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.0.10)\n",
            "Collecting async-timeout<5.0.0,>=4.0.0\n",
            "  Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n",
            "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (8.2.2)\n",
            "Requirement already satisfied: numpy<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (1.22.4)\n",
            "Collecting openapi-schema-pydantic<2.0,>=1.2\n",
            "  Downloading openapi_schema_pydantic-1.2.4-py3-none-any.whl (90 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.0/90.0 kB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tqdm>=4.48.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (4.65.0)\n",
            "Collecting aiohttp<4.0.0,>=3.8.3\n",
            "  Downloading aiohttp-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.27.1)\n",
            "Requirement already satisfied: numexpr<3.0.0,>=2.8.4 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.8.4)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.1.0)\n",
            "Collecting multidict<7.0,>=4.5\n",
            "  Downloading multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m114.5/114.5 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting frozenlist>=1.1.1\n",
            "  Downloading frozenlist-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (149 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m149.6/149.6 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting aiosignal>=1.1.2\n",
            "  Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n",
            "Collecting yarl<2.0,>=1.0\n",
            "  Downloading yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (268 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m31.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (2.0.12)\n",
            "Collecting marshmallow-enum<2.0.0,>=1.5.1\n",
            "  Downloading marshmallow_enum-1.5.1-py2.py3-none-any.whl (4.2 kB)\n",
            "Collecting marshmallow<4.0.0,>=3.3.0\n",
            "  Downloading marshmallow-3.19.0-py3-none-any.whl (49 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.1/49.1 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting typing-inspect>=0.4.0\n",
            "  Downloading typing_inspect-0.8.0-py3-none-any.whl (8.7 kB)\n",
            "Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<2,>=1->langchain) (4.5.0)\n",
            "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (1.26.15)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (3.4)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (2022.12.7)\n",
            "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>1.3->langchain) (2.0.2)\n",
            "Requirement already satisfied: packaging>=17.0 in /usr/local/lib/python3.10/dist-packages (from marshmallow<4.0.0,>=3.3.0->dataclasses-json<0.6.0,>=0.5.7->langchain) (23.1)\n",
            "Collecting mypy-extensions>=0.3.0\n",
            "  Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n",
            "Installing collected packages: mypy-extensions, multidict, marshmallow, frozenlist, async-timeout, yarl, typing-inspect, openapi-schema-pydantic, marshmallow-enum, aiosignal, dataclasses-json, aiohttp, openai, langchain\n",
            "Successfully installed aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 dataclasses-json-0.5.7 frozenlist-1.3.3 langchain-0.0.151 marshmallow-3.19.0 marshmallow-enum-1.5.1 multidict-6.0.4 mypy-extensions-1.0.0 openai-0.27.5 openapi-schema-pydantic-1.2.4 typing-inspect-0.8.0 yarl-1.9.2\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.agents import load_tools\n",
        "from langchain.agents import initialize_agent\n",
        "from langchain.agents import AgentType\n",
        "from langchain.llms import OpenAI\n",
        "from langchain.chat_models import ChatOpenAI\n",
        "from langchain.chains.conversation.memory import ConversationBufferWindowMemory"
      ],
      "metadata": {
        "id": "vWAHCTE7oa8F"
      },
      "execution_count": 23,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from getpass import getpass\n",
        "\n",
        "# Read the key from the environment, or prompt for it without echoing, so the\n",
        "# secret is never typed into (and saved with) the notebook itself.\n",
        "OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY') or getpass('OpenAI API key: ')"
      ],
      "metadata": {
        "id": "Ojv0fTD_oqxj"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Chat model used by the agent; temperature is set to 0.\n",
        "llm = ChatOpenAI(\n",
        "    model_name=\"gpt-3.5-turbo\",\n",
        "    temperature=0,\n",
        "    openai_api_key=OPENAI_API_KEY,\n",
        ")"
      ],
      "metadata": {
        "id": "VVwq0-yHoczv"
      },
      "execution_count": 24,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import ast\n",
        "import operator\n",
        "\n",
        "from langchain.tools import BaseTool\n",
        "\n",
        "# Arithmetic operators the tool will evaluate, keyed by AST node type.\n",
        "_BINARY_OPERATORS = {\n",
        "    ast.Add: operator.add,\n",
        "    ast.Sub: operator.sub,\n",
        "    ast.Mult: operator.mul,\n",
        "    ast.Div: operator.truediv,\n",
        "    ast.FloorDiv: operator.floordiv,\n",
        "    ast.Mod: operator.mod,\n",
        "    ast.Pow: operator.pow,\n",
        "}\n",
        "_UNARY_OPERATORS = {ast.UAdd: operator.pos, ast.USub: operator.neg}\n",
        "# Side-effect-free builtins that eval() previously exposed and that are\n",
        "# plausible inside a math expression.\n",
        "_ALLOWED_FUNCTIONS = {'abs': abs, 'round': round, 'min': min, 'max': max}\n",
        "\n",
        "\n",
        "def _evaluate_node(node):\n",
        "    \"\"\"Recursively evaluate a whitelisted arithmetic AST node.\n",
        "\n",
        "    Only int/float literals, the operators in _BINARY_OPERATORS and\n",
        "    _UNARY_OPERATORS, and positional calls to _ALLOWED_FUNCTIONS are accepted.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if the expression contains anything else (names,\n",
        "            attribute access, strings, keyword arguments, ...).\n",
        "    \"\"\"\n",
        "    if isinstance(node, ast.Expression):\n",
        "        return _evaluate_node(node.body)\n",
        "    # type(...) check (not isinstance) so that True/False are rejected too.\n",
        "    if isinstance(node, ast.Constant) and type(node.value) in (int, float):\n",
        "        return node.value\n",
        "    if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPERATORS:\n",
        "        apply = _BINARY_OPERATORS[type(node.op)]\n",
        "        return apply(_evaluate_node(node.left), _evaluate_node(node.right))\n",
        "    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPERATORS:\n",
        "        return _UNARY_OPERATORS[type(node.op)](_evaluate_node(node.operand))\n",
        "    if (\n",
        "        isinstance(node, ast.Call)\n",
        "        and isinstance(node.func, ast.Name)\n",
        "        and node.func.id in _ALLOWED_FUNCTIONS\n",
        "        and not node.keywords\n",
        "    ):\n",
        "        arguments = [_evaluate_node(arg) for arg in node.args]\n",
        "        return _ALLOWED_FUNCTIONS[node.func.id](*arguments)\n",
        "    raise ValueError(f'Unsupported element in math expression: {type(node).__name__}')\n",
        "\n",
        "\n",
        "class EvaluateMathExpression(BaseTool):\n",
        "    \"\"\"Agent tool that evaluates an arithmetic expression given as text.\"\"\"\n",
        "\n",
        "    name = \"Math Evaluation\"\n",
        "    description = 'use this tool to evaluate a math expression.'\n",
        "\n",
        "    def _run(self, expr: str):\n",
        "        \"\"\"Evaluate ``expr`` and return the numeric result.\n",
        "\n",
        "        SECURITY: ``expr`` is produced by the language model, so it is untrusted.\n",
        "        It used to be passed to eval(), which would execute arbitrary Python;\n",
        "        it is now parsed with ``ast`` and only whitelisted arithmetic is run.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if ``expr`` is not valid syntax or uses an\n",
        "                unsupported construct.\n",
        "        \"\"\"\n",
        "        try:\n",
        "            tree = ast.parse(expr.strip(), mode='eval')\n",
        "        except SyntaxError as exc:\n",
        "            raise ValueError(f'Invalid math expression: {expr!r}') from exc\n",
        "        return _evaluate_node(tree)\n",
        "\n",
        "    async def _arun(self, query: str):\n",
        "        \"\"\"Async variant of the tool; not implemented, always raises.\"\"\"\n",
        "        raise NotImplementedError(\"Async operation not supported yet\")\n",
        "\n",
        "tools = [EvaluateMathExpression()]"
      ],
      "metadata": {
        "id": "qFPyDKY0ufLs"
      },
      "execution_count": 25,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Build the conversational chat agent around the math tool. The agent type is\n",
        "# selected through the AgentType enum (already imported above) instead of a\n",
        "# hand-typed string, and the memory keeps a window of k=5 under 'chat_history'.\n",
        "agent = initialize_agent(\n",
        "    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,\n",
        "    tools=tools,\n",
        "    llm=llm,\n",
        "    verbose=True,\n",
        "    max_iterations=3,\n",
        "    early_stopping_method='generate',\n",
        "    memory=ConversationBufferWindowMemory(\n",
        "        memory_key='chat_history',\n",
        "        k=5,\n",
        "        return_messages=True,\n",
        "    ),\n",
        ")"
      ],
      "metadata": {
        "id": "fkUISONKCMGQ"
      },
      "execution_count": 27,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Print every message template in the agent's current prompt, one per line.\n",
        "prompt_messages = agent.agent.llm_chain.prompt.messages\n",
        "for prompt_message in prompt_messages:\n",
        "    print(prompt_message)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tQXiJpQ1h5xc",
        "outputId": "fb413b11-6867-4555-c967-e09189ea04aa"
      },
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "prompt=PromptTemplate(input_variables=[], output_parser=None, partial_variables={}, template='Assistant is a large language model trained by OpenAI.\\n\\nAssistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.\\n\\nAssistant is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics.\\n\\nOverall, Assistant is a powerful system that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist.', template_format='f-string', validate_template=True) additional_kwargs={}\n",
            "variable_name='chat_history'\n",
            "prompt=PromptTemplate(input_variables=['input'], output_parser=None, partial_variables={}, template='TOOLS\\n------\\nAssistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are:\\n\\n> Math Evaluation: use this tool to evaluate a math expression.\\n\\nRESPONSE FORMAT INSTRUCTIONS\\n----------------------------\\n\\nWhen responding to me, please output a response in one of two formats:\\n\\n**Option 1:**\\nUse this if you want the human to use a tool.\\nMarkdown code snippet formatted in the following schema:\\n\\n```json\\n{{\\n    \"action\": string \\\\ The action to take. Must be one of Math Evaluation\\n    \"action_input\": string \\\\ The input to the action\\n}}\\n```\\n\\n**Option #2:**\\nUse this if you want to respond directly to the human. Markdown code snippet formatted in the following schema:\\n\\n```json\\n{{\\n    \"action\": \"Final Answer\",\\n    \"action_input\": string \\\\ You should put what you want to return to use here\\n}}\\n```\\n\\nUSER\\'S INPUT\\n--------------------\\nHere is the user\\'s input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):\\n\\n{input}', template_format='f-string', validate_template=True) additional_kwargs={}\n",
            "variable_name='agent_scratchpad'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Ask the agent the question; a plain string is used because nothing is interpolated.\n",
        "agent(\"What is 2 * 2 * 0.13 - 1.001?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wJPk8RebvGC3",
        "outputId": "0f7f012a-eb64-44d3-8314-c77f9d3318e9"
      },
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "\n",
            "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
            "\u001b[32;1m\u001b[1;3m{\n",
            "    \"action\": \"Final Answer\",\n",
            "    \"action_input\": \"-0.341\"\n",
            "}\u001b[0m\n",
            "\n",
            "\u001b[1m> Finished chain.\u001b[0m\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'input': 'What is 2 * 2 * 0.13 - 1.001?',\n",
              " 'chat_history': [],\n",
              " 'output': '-0.341'}"
            ]
          },
          "metadata": {},
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Reference value computed by Python itself, for comparison with the agent's answer.\n",
        "# The parentheses only spell out the default left-to-right evaluation order.\n",
        "((2 * 2) * 0.13) - 1.001"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "KvC-1ZB9hURY",
        "outputId": "9797b9d0-3223-4597-d749-e93a752c897a"
      },
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "-0.48099999999999987"
            ]
          },
          "metadata": {},
          "execution_count": 30
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.agents.conversational_chat.prompt import PREFIX\n",
        "\n",
        "# Extra instruction appended to LangChain's PREFIX so the model defers to the\n",
        "# math tool. NOTE(review): \"should always refers\" is ungrammatical; the text is\n",
        "# kept verbatim because it is the prompt the recorded outputs were produced with.\n",
        "math_disclaimer = '''\n",
        "Unfortunately, Assistant is terrible at maths. Assistant should always refers to available tools and never try to answer math questions by itself\n",
        "'''\n",
        "system_message = PREFIX + \"\\n\" + math_disclaimer"
      ],
      "metadata": {
        "id": "clEarofMhhJu"
      },
      "execution_count": 31,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Rebuild the prompt with the custom system message and swap it into the\n",
        "# existing agent's LLM chain.\n",
        "conversational_agent = agent.agent\n",
        "new_prompt = conversational_agent.create_prompt(tools=tools, system_message=system_message)\n",
        "conversational_agent.llm_chain.prompt = new_prompt"
      ],
      "metadata": {
        "id": "ZARVE-emhnIN"
      },
      "execution_count": 32,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Ask the same question again; a plain string is used because nothing is interpolated.\n",
        "agent(\"What is 2 * 2 * 0.13 - 1.001?\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nm1_Im67hroZ",
        "outputId": "cd7abc79-8be3-470d-d5a2-70f1c8b2623e"
      },
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "\n",
            "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
            "\u001b[32;1m\u001b[1;3m{\n",
            "    \"action\": \"Math Evaluation\",\n",
            "    \"action_input\": \"2 * 2 * 0.13 - 1.001\"\n",
            "}\u001b[0m\n",
            "Observation: \u001b[36;1m\u001b[1;3m-0.48099999999999987\u001b[0m\n",
            "Thought:\u001b[32;1m\u001b[1;3m{\n",
            "    \"action\": \"Final Answer\",\n",
            "    \"action_input\": \"-0.481\"\n",
            "}\u001b[0m\n",
            "\n",
            "\u001b[1m> Finished chain.\u001b[0m\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'input': 'What is 2 * 2 * 0.13 - 1.001?',\n",
              " 'chat_history': [HumanMessage(content='What is 2 * 2 * 0.13 - 1.001?', additional_kwargs={}),\n",
              "  AIMessage(content='-0.341', additional_kwargs={})],\n",
              " 'output': '-0.481'}"
            ]
          },
          "metadata": {},
          "execution_count": 33
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Show the messages currently held in the agent's conversation memory.\n",
        "conversation_memory = agent.memory\n",
        "conversation_memory.buffer"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6g0ri68KEC0q",
        "outputId": "6d348ff3-2e6e-4f7e-bd0c-8d637459c1e9"
      },
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[HumanMessage(content='What is 2 * 2 * 0.13 - 1.001?', additional_kwargs={}),\n",
              " AIMessage(content='-0.341', additional_kwargs={}),\n",
              " HumanMessage(content='What is 2 * 2 * 0.13 - 1.001?', additional_kwargs={}),\n",
              " AIMessage(content='-0.481', additional_kwargs={})]"
            ]
          },
          "metadata": {},
          "execution_count": 34
        }
      ]
    }
  ]
}